#
# This file contains the basic building blocks of the Inception network as defined in:
#
#   https://arxiv.org/pdf/1512.00567.pdf
#
# and in the TensorFlow implementation
#

#
# Convolution layer followed by Batch Normalization and a Rectified Linear (ReLU) activation.
#
# ConvBNReLULayer: layer factory that chains, via Sequential, a ConvolutionalLayer,
# a BatchNormalizationLayer and ReLU, in that order.
#
# Parameters:
#   numOutputChannels : forwarded to ConvolutionalLayer as its first argument.
#   filterShape       : forwarded to ConvolutionalLayer as its second argument, e.g. (3:3).
#   stride            : forwarded to ConvolutionalLayer as `stride`, e.g. (1:1).
#   pad               : optional, default true; forwarded to ConvolutionalLayer as `pad`.
#   bnTimeConst       : optional, default 4096; forwarded to BatchNormalizationLayer
#                       as `normalizationTimeConstant`.
#
# Fixed settings: the convolution uses init = "glorotUniform" and bias = false;
# batch normalization uses spatialRank = 2 and useCntkEngine = false.
# NOTE(review): the bias is presumably dropped because the batch normalization that
# follows provides its own shift term -- confirm.
# NOTE(review): useCntkEngine = false presumably selects the non-CNTK (cuDNN) batch
# normalization implementation -- confirm against the BatchNormalizationLayer docs.
ConvBNReLULayer {numOutputChannels, filterShape, stride, pad = true, bnTimeConst = 4096} = Sequential(
    ConvolutionalLayer {numOutputChannels, filterShape, init = "glorotUniform", stride = stride, pad = pad, bias = false} :
    BatchNormalizationLayer {spatialRank = 2, normalizationTimeConstant = bnTimeConst, useCntkEngine = false} :
    ReLU
)

#
# Figure 5 from https://arxiv.org/pdf/1512.00567.pdf
# Modified by adding a 5x5 branch to match the TensorFlow implementation
#
# InceptionBlock1: applies four branches to the same input x and splices their
# outputs along axis 3. Every operation uses stride (1:1) with pad = true.
#
# Parameters:
#   numOutputChannels1x1     : output channels of the single 1x1 convolution branch.
#   numOutputChannels5x5     : array indexed [0] and [1]; output channels of the
#                              1x1 and 5x5 convolutions of the 5x5 branch.
#   numOutputChannels3x3_3x3 : array indexed [0]..[2]; output channels of the
#                              1x1, 3x3 and 3x3 convolutions of the double-3x3 branch.
#   numOutputChannelsPool    : output channels of the 1x1 convolution that follows
#                              the average pooling.
#   bnTimeConst              : forwarded to every ConvBNReLULayer as `bnTimeConst`.
#
# The definition evaluates to the record's `apply` member, i.e. a function of x.
# NOTE(review): axis 3 is presumably the channel axis (1-based W x H x C layout) -- confirm.
InceptionBlock1{numOutputChannels1x1, numOutputChannels5x5, numOutputChannels3x3_3x3, numOutputChannelsPool, bnTimeConst} =
{
    apply(x) = {
        # Branch 1: a single 1x1 convolution on x
        branch1x1 = ConvBNReLULayer{numOutputChannels1x1, (1:1), (1:1), pad = true, bnTimeConst = bnTimeConst}(x)

        # Branch 2: 1x1 convolution followed by a 5x5 convolution
        branch5x5 = Sequential(
            ConvBNReLULayer{numOutputChannels5x5[0], (1:1), (1:1), pad = true, bnTimeConst = bnTimeConst} :
            ConvBNReLULayer{numOutputChannels5x5[1], (5:5), (1:1), pad = true, bnTimeConst = bnTimeConst}
        )(x)

        # Branch 3: 1x1 convolution followed by two 3x3 convolutions
        branch3x3_3x3 = Sequential(
            ConvBNReLULayer{numOutputChannels3x3_3x3[0], (1:1), (1:1), pad = true, bnTimeConst = bnTimeConst} :
            ConvBNReLULayer{numOutputChannels3x3_3x3[1], (3:3), (1:1), pad = true, bnTimeConst = bnTimeConst} :
            ConvBNReLULayer{numOutputChannels3x3_3x3[2], (3:3), (1:1), pad = true, bnTimeConst = bnTimeConst}
        )(x)

        # Branch 4: 3x3 average pooling (stride 1, padded) followed by a 1x1 convolution
        branch_pool = Sequential(
            AveragePoolingLayer{(3:3), stride = (1:1), pad = true} :
            ConvBNReLULayer{numOutputChannelsPool, (1:1), (1:1), pad = true, bnTimeConst = bnTimeConst}
        )(x)

        # Concatenate the four branch outputs along axis 3
        out = Splice((branch1x1:branch5x5:branch3x3_3x3:branch_pool), axis=3)
    }.out
}.apply

# InceptionBlock2: applies three branches to the same input x and splices their
# outputs along axis 3. Each branch ends in an operation with stride (2:2) and
# pad = false (a 3x3 convolution or a 3x3 max pooling).
# NOTE(review): this looks like the grid-size-reduction block that follows the
# Figure 5 modules -- confirm against the paper / TensorFlow model.
#
# Parameters:
#   numOutputChannels3x3     : output channels of the single strided 3x3 convolution.
#   numOutputChannels3x3_3x3 : array indexed [0]..[2]; output channels of the
#                              1x1, 3x3 and strided 3x3 convolutions of the second branch.
#   bnTimeConst              : forwarded to every ConvBNReLULayer as `bnTimeConst`.
#
# The pooling branch has no convolution, so it takes no channel-count parameter.
# The definition evaluates to the record's `apply` member, i.e. a function of x.
InceptionBlock2{numOutputChannels3x3, numOutputChannels3x3_3x3, bnTimeConst} =
{
    apply(x) = {
        # Branch 1: a single 3x3 convolution, stride 2, no padding
        branch3x3 = ConvBNReLULayer{numOutputChannels3x3, (3:3), (2:2), pad = false, bnTimeConst = bnTimeConst}(x)

        # Branch 2: 1x1 and 3x3 convolutions (stride 1, padded), then a 3x3 convolution with stride 2 and no padding
        branch3x3_3x3 = Sequential(
            ConvBNReLULayer{numOutputChannels3x3_3x3[0], (1:1), (1:1), pad = true, bnTimeConst = bnTimeConst} :
            ConvBNReLULayer{numOutputChannels3x3_3x3[1], (3:3), (1:1), pad = true, bnTimeConst = bnTimeConst} :
            ConvBNReLULayer{numOutputChannels3x3_3x3[2], (3:3), (2:2), pad = false, bnTimeConst = bnTimeConst}
        )(x)

        # Branch 3: 3x3 max pooling, stride 2, no padding, applied directly to x
        branch_pool = MaxPoolingLayer{(3:3), stride = (2:2), pad = false}(x)

        # Concatenate the three branch outputs along axis 3
        out = Splice((branch3x3:branch3x3_3x3:branch_pool), axis=3)
    }.out
}.apply

#
# Figure 6 from https://arxiv.org/pdf/1512.00567.pdf
#
# InceptionBlock3: applies four branches to the same input x and splices their
# outputs along axis 3. Every operation uses stride (1:1) with pad = true.
# The "7x7" branches are built from (1:7) and (7:1) filter shapes rather than
# a single (7:7) filter.
#
# Parameters:
#   numOutputChannels1x1     : output channels of the single 1x1 convolution branch.
#   numOutputChannels7x7     : array indexed [0]..[2]; output channels of the
#                              1x1, (1:7) and (7:1) convolutions, in that order.
#   numOutputChannels7x7_7x7 : array indexed [0]..[4]; output channels of the
#                              1x1, (7:1), (1:7), (7:1) and (1:7) convolutions, in that order.
#   numOutputChannelsPool    : output channels of the 1x1 convolution that follows
#                              the average pooling.
#   bnTimeConst              : forwarded to every ConvBNReLULayer as `bnTimeConst`.
#
# The definition evaluates to the record's `apply` member, i.e. a function of x.
InceptionBlock3{numOutputChannels1x1, numOutputChannels7x7, numOutputChannels7x7_7x7, numOutputChannelsPool, bnTimeConst} =
{
    apply(x) = {
        # Branch 1: a single 1x1 convolution on x
        branch1x1 = ConvBNReLULayer{numOutputChannels1x1, (1:1), (1:1), pad=true, bnTimeConst = bnTimeConst}(x)

        # Branch 2: 1x1 convolution followed by a (1:7) and then a (7:1) convolution
        branch7x7 = Sequential(
            ConvBNReLULayer{numOutputChannels7x7[0], (1:1), (1:1), pad = true, bnTimeConst = bnTimeConst} :
            ConvBNReLULayer{numOutputChannels7x7[1], (1:7), (1:1), pad = true, bnTimeConst = bnTimeConst} :
            ConvBNReLULayer{numOutputChannels7x7[2], (7:1), (1:1), pad = true, bnTimeConst = bnTimeConst}
        )(x)

        # Branch 3: 1x1 convolution followed by two (7:1),(1:7) pairs.
        # Note the order here starts with (7:1), the opposite of branch7x7.
        branch7x7_7x7 = Sequential(
            ConvBNReLULayer{numOutputChannels7x7_7x7[0], (1:1), (1:1), pad = true, bnTimeConst = bnTimeConst} :
            ConvBNReLULayer{numOutputChannels7x7_7x7[1], (7:1), (1:1), pad = true, bnTimeConst = bnTimeConst} :
            ConvBNReLULayer{numOutputChannels7x7_7x7[2], (1:7), (1:1), pad = true, bnTimeConst = bnTimeConst} :
            ConvBNReLULayer{numOutputChannels7x7_7x7[3], (7:1), (1:1), pad = true, bnTimeConst = bnTimeConst} :
            ConvBNReLULayer{numOutputChannels7x7_7x7[4], (1:7), (1:1), pad = true, bnTimeConst = bnTimeConst}
        )(x)

        # Branch 4: 3x3 average pooling (stride 1, padded) followed by a 1x1 convolution
        branch_pool = Sequential(
            AveragePoolingLayer{(3:3), stride = (1:1), pad = true} :
            ConvBNReLULayer{numOutputChannelsPool, (1:1), (1:1), pad = true, bnTimeConst = bnTimeConst}
        )(x)

        # Concatenate the four branch outputs along axis 3
        out = Splice((branch1x1:branch7x7:branch7x7_7x7:branch_pool), axis=3)
    }.out
}.apply

# InceptionBlock4: applies three branches to the same input x and splices their
# outputs along axis 3. Each branch ends in an operation with stride (2:2) and
# pad = false (a 3x3 convolution or a 3x3 max pooling).
# NOTE(review): this looks like the grid-size-reduction block that follows the
# Figure 6 modules -- confirm against the paper / TensorFlow model.
#
# Parameters:
#   numOutputChannels3x3   : array indexed [0] and [1]; output channels of the
#                            1x1 and strided 3x3 convolutions of the first branch.
#   numOutputChannels7x7x3 : array indexed [0]..[3]; output channels of the
#                            1x1, (1:7), (7:1) and strided 3x3 convolutions of the second branch.
#   bnTimeConst            : forwarded to every ConvBNReLULayer as `bnTimeConst`.
#
# The pooling branch has no convolution, so it takes no channel-count parameter.
# The definition evaluates to the record's `apply` member, i.e. a function of x.
InceptionBlock4{numOutputChannels3x3, numOutputChannels7x7x3, bnTimeConst} =
{
    apply(x) = {
        # Branch 1: 1x1 convolution, then a 3x3 convolution with stride 2 and no padding
        branch3x3 = Sequential(
            ConvBNReLULayer{numOutputChannels3x3[0], (1:1), (1:1), pad = true, bnTimeConst = bnTimeConst} :
            ConvBNReLULayer{numOutputChannels3x3[1], (3:3), (2:2), pad = false, bnTimeConst = bnTimeConst}
        )(x)

        # Branch 2: 1x1, (1:7) and (7:1) convolutions (stride 1, padded),
        # then a 3x3 convolution with stride 2 and no padding
        branch7x7x3 = Sequential(
            ConvBNReLULayer{numOutputChannels7x7x3[0], (1:1), (1:1), pad = true, bnTimeConst = bnTimeConst} :
            ConvBNReLULayer{numOutputChannels7x7x3[1], (1:7), (1:1), pad = true, bnTimeConst = bnTimeConst} :
            ConvBNReLULayer{numOutputChannels7x7x3[2], (7:1), (1:1), pad = true, bnTimeConst = bnTimeConst} :
            ConvBNReLULayer{numOutputChannels7x7x3[3], (3:3), (2:2), pad = false, bnTimeConst = bnTimeConst}
        )(x)

        # Branch 3: 3x3 max pooling, stride 2, no padding, applied directly to x
        branch_pool = MaxPoolingLayer{(3:3), stride = (2:2), pad = false}(x)

        # Concatenate the three branch outputs along axis 3
        out = Splice((branch3x3:branch7x7x3:branch_pool), axis=3)
    }.out
}.apply

#
# Figure 7 from https://arxiv.org/pdf/1512.00567.pdf
#
# InceptionBlock5: applies four branches to the same input x and splices their
# outputs along axis 3. Every operation uses stride (1:1) with pad = true.
# Two of the branches fan out at their end: a shared intermediate result feeds a
# (1:3) and a (3:1) convolution in parallel, and the two outputs are spliced
# along axis 3 before the final splice.
#
# Parameters:
#   numOutputChannels1x1     : output channels of the single 1x1 convolution branch.
#   numOutputChannels3x3     : array indexed [0]..[2]; output channels of the 1x1
#                              convolution and of the parallel (1:3) and (3:1) convolutions.
#   numOutputChannels3x3_3x3 : array indexed [0]..[3]; output channels of the 1x1 and 3x3
#                              convolutions and of the parallel (1:3) and (3:1) convolutions.
#   numOutputChannelsPool    : output channels of the 1x1 convolution that follows
#                              the average pooling.
#   bnTimeConst              : forwarded to every ConvBNReLULayer as `bnTimeConst`.
#
# The definition evaluates to the record's `apply` member, i.e. a function of x.
InceptionBlock5{numOutputChannels1x1, numOutputChannels3x3, numOutputChannels3x3_3x3, numOutputChannelsPool, bnTimeConst} =
{
    apply(x) = {
        # Branch 1: a single 1x1 convolution on x
        branch1x1 = ConvBNReLULayer{numOutputChannels1x1, (1:1), (1:1), pad = true, bnTimeConst = bnTimeConst}(x)

        # Branch 2: 1x1 convolution whose output (branch3x3_i) feeds a (1:3) and a (3:1)
        # convolution in parallel; their outputs are spliced along axis 3
        branch3x3_i = ConvBNReLULayer{numOutputChannels3x3[0], (1:1), (1:1), pad = true, bnTimeConst = bnTimeConst}(x)
        branch3x3_1 = ConvBNReLULayer{numOutputChannels3x3[1], (1:3), (1:1), pad = true, bnTimeConst = bnTimeConst}(branch3x3_i)
        branch3x3_2 = ConvBNReLULayer{numOutputChannels3x3[2], (3:1), (1:1), pad = true, bnTimeConst = bnTimeConst}(branch3x3_i)
        branch3x3   = Splice((branch3x3_1:branch3x3_2), axis=3)

        # Branch 3: 1x1 and 3x3 convolutions produce the shared intermediate branch3x3_3x3_i
        branch3x3_3x3_i = Sequential(
            ConvBNReLULayer{numOutputChannels3x3_3x3[0], (1:1), (1:1), pad = true, bnTimeConst = bnTimeConst} :
            ConvBNReLULayer{numOutputChannels3x3_3x3[1], (3:3), (1:1), pad = true, bnTimeConst = bnTimeConst}
        )(x)

        # ...which then feeds a (1:3) and a (3:1) convolution in parallel; outputs are spliced along axis 3
        branch3x3_3x3_1 = ConvBNReLULayer{numOutputChannels3x3_3x3[2], (1:3), (1:1), pad = true, bnTimeConst = bnTimeConst}(branch3x3_3x3_i)
        branch3x3_3x3_2 = ConvBNReLULayer{numOutputChannels3x3_3x3[3], (3:1), (1:1), pad = true, bnTimeConst = bnTimeConst}(branch3x3_3x3_i)
        branch3x3_3x3   = Splice((branch3x3_3x3_1:branch3x3_3x3_2), axis=3)

        # Branch 4: 3x3 average pooling (stride 1, padded) followed by a 1x1 convolution
        branch_pool = Sequential(
            AveragePoolingLayer{(3:3), stride = (1:1), pad = true} :
            ConvBNReLULayer{numOutputChannelsPool, (1:1), (1:1), pad = true, bnTimeConst = bnTimeConst}
        )(x)

        # Concatenate the four branch outputs along axis 3
        out = Splice((branch1x1:branch3x3:branch3x3_3x3:branch_pool), axis=3)
    }.out
}.apply
